package drds.plus.sql_process.rule;

import com.google.common.collect.Lists;
import drds.plus.common.lifecycle.AbstractLifecycle;
import drds.plus.common.model.comparative.List;
import drds.plus.common.model.comparative.*;
import drds.plus.rule_engine.Route;
import drds.plus.rule_engine.rule_calculate.DataNodeDataScatterInfo;
import drds.plus.rule_engine.rule_calculate.RuleCalculateDiffrentException;
import drds.plus.rule_engine.rule_calculate.RuleCalculateResult;
import drds.plus.rule_engine.table_rule.TableRule;
import drds.plus.sql_process.abstract_syntax_tree.ObjectCreateFactory;
import drds.plus.sql_process.abstract_syntax_tree.expression.NullValue;
import drds.plus.sql_process.abstract_syntax_tree.expression.bind_value.BindValue;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.Item;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.column.Column;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.*;
import drds.plus.sql_process.type.Type;
import drds.plus.sql_process.type.Types;
import drds.plus.sql_process.utils.OptimizerUtils;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;

import java.util.*;

/**
 * Optimizer-side helpers built on top of the rule engine's {@linkplain Route}.
 * <p>
 * Converts optimizer {@linkplain Filter} trees into the {@linkplain Comparative} conditions the
 * rule engine understands, and delegates to the wrapped {@linkplain Route} to compute target data
 * nodes and tables. {@link #doInit()} initializes the wrapped route if it is not initialized yet;
 * {@link #doDestroy()} destroys it if it is initialized.
 */
@Slf4j
public class RouteOptimizer extends AbstractLifecycle {
    @Override
    public Logger log() {
        return log;
    }

    private final Route route;

    /**
     * @param route the rule-engine router all routing and rule lookups are delegated to
     */
    public RouteOptimizer(Route route) {
        this.route = route;
    }

    /**
     * Converts a {@linkplain Filter} expression into the {@linkplain Comparative} that restricts
     * the given partition column, walking the filter tree recursively.
     *
     * @param filter     the filter tree to inspect; {@code null} yields {@code null}
     * @param columnName the partition column name, matched case-insensitively
     * @return the condition on {@code columnName}, or {@code null} when the filter does not
     *         restrict that column in a usable way (treated as a full table scan)
     * @throws IllegalArgumentException if a logical filter's operation is neither and nor or
     * @throws RuntimeException         if a value compared with the partition column is not
     *                                  {@link Comparable}
     */
    public static Comparative getComparative(Filter filter, String columnName) {
        if (filter == null) {
            return null;
        }
        // A negated expression cannot be turned into a condition on the partition column.
        if ("NOT".equalsIgnoreCase(filter.getFunctionName())) {
            return null;
        }

        if (filter instanceof LogicalOperationFilter) {
            if (filter.isNot()) {
                return null;
            }
            List junction;
            switch (filter.getOperation()) {
                case and:
                    junction = new And();
                    break;
                case or:
                    junction = new Or();
                    break;
                default:
                    throw new IllegalArgumentException("unsupported logical operation: " + filter.getOperation());
            }
            return combineSubComparatives(junction, ((LogicalOperationFilter) filter).getFilterList(), columnName);
        }
        if (filter instanceof OrsFilter) {
            if (filter.isNot()) {
                return null;
            }
            return combineSubComparatives(new Or(), ((OrsFilter) filter).getFilterList(), columnName);
        }
        if (filter instanceof BooleanFilter) {
            return getBooleanComparative((BooleanFilter) filter, columnName);
        }
        // Any other filter type yields no condition, i.e. a full table scan.
        return null;
    }

    /**
     * Adds the comparative of every sub filter that restricts {@code columnName} to
     * {@code junction}, then simplifies the result.
     *
     * @param junction   an empty {@link And} or {@link Or} to collect into
     * @param subFilters the child filters to convert recursively
     * @param columnName the partition column name
     * @return {@code null} if no child restricts the column, or if {@code junction} is an
     *         {@link Or} and some branch does not restrict it (with partition column id,
     *         {@code id = 1 or id = 3} is usable but {@code id = 1 or name = 2} needs a full scan);
     *         the single collected comparative when there is only one; otherwise {@code junction}
     */
    private static Comparative combineSubComparatives(List junction, Iterable<? extends Filter> subFilters, String columnName) {
        boolean everySubFilterRestrictsColumn = true;
        for (Filter subFilter : subFilters) {
            Comparative subComparative = getComparative(subFilter, columnName);
            if (subComparative == null) {
                everySubFilterRestrictsColumn = false;
            } else {
                junction.addComparative(subComparative);
            }
        }

        if (junction.getComparativeList() == null || junction.getComparativeList().isEmpty()) {
            return null;
        }
        if (junction instanceof Or && !everySubFilterRestrictsColumn) {
            return null;
        }
        if (junction.getComparativeList().size() == 1) {
            // Unwrap a junction holding only one condition.
            return junction.getComparativeList().get(0);
        }
        return junction;
    }

    /**
     * Converts a single comparison or IN filter into a {@link Comparative}.
     *
     * @return the comparative, or {@code null} when the filter has no usable column/value pair,
     *         compares two expressions with each other, is negated, or targets another column
     */
    private static Comparative getBooleanComparative(BooleanFilter booleanFilter, String columnName) {
        // Need a column and at least one non-null value (single value or IN list).
        if (isNull(booleanFilter.getColumn()) || (isNull(booleanFilter.getValue()) && isNull(booleanFilter.getValueList()))) {
            return null;
        }

        Object column = convertNowFunction(booleanFilter.getColumn());
        Object value = convertNowFunction(booleanFilter.getValue());
        // Expression vs. expression, e.g. A > B or A > B + 1: not a constant condition.
        if (column instanceof Item && value instanceof Item) {
            return null;
        }
        // One side must be a plain column.
        if (!(column instanceof Column || value instanceof Column)) {
            return null;
        }
        if (booleanFilter.isNot()) {
            return null;
        }

        if (booleanFilter.getOperation() == Operation.in) {
            return getInComparative(booleanFilter, columnName);
        }
        return getComparisonComparative(booleanFilter, column, value, columnName);
    }

    /**
     * Expands {@code col IN (v1, v2, ...)} into an {@link Or} of {@code col = vi} comparatives.
     * Values whose equality yields no comparative (e.g. nulls) are dropped.
     * IN never appears in reversed (value-first) form.
     *
     * @return the {@link Or}, or {@code null} if there is no value list or every value was dropped
     */
    private static Comparative getInComparative(BooleanFilter inFilter, String columnName) {
        if (inFilter.getValueList() == null) {
            // Only a single (non-list) value is present: nothing to expand, fall back to a full scan.
            return null;
        }

        Type type = null;
        if (inFilter.getColumn() instanceof Column) {
            type = ((Column) inFilter.getColumn()).getType();
        }

        List or = new Or();
        for (Object object : inFilter.getValueList()) {
            BooleanFilter equal = ObjectCreateFactory.createBooleanFilter();
            equal.setOperation(Operation.equal);
            equal.setColumn(inFilter.getColumn());
            equal.setValue(getValue(object, type));

            Comparative subComparative = getComparative(equal, columnName);
            if (subComparative != null) {
                or.addComparative(subComparative);
            }
        }

        if (or.getComparativeList().isEmpty()) {
            return null;
        }
        return or;
    }

    /**
     * Maps a binary comparison (=, >, >=, <, <=) between the partition column and a value to a
     * {@link Comparative}. Reversed forms such as {@code 1 = id} are normalized by exchanging the
     * comparison direction.
     *
     * @param booleanFilter the comparison filter
     * @param column        the left operand after NOW() conversion
     * @param value         the right operand after NOW() conversion
     * @param columnName    the partition column name, matched case-insensitively
     * @return the comparative, or {@code null} for other operations, another column, or a null value
     * @throws RuntimeException if the compared value is not {@link Comparable}
     */
    private static Comparative getComparisonComparative(BooleanFilter booleanFilter, Object column, Object value, String columnName) {
        int operationComp;
        switch (booleanFilter.getOperation()) {
            case greater_than:
                operationComp = Comparative.greater_than;
                break;
            case equal:
                operationComp = Comparative.equal;
                break;
            case greater_than_or_equal:
                operationComp = Comparative.greater_than_or_equal;
                break;
            case less_than:
                operationComp = Comparative.less_than;
                break;
            case less_than_or_equal:
                operationComp = Comparative.less_than_or_equal;
                break;
            default:
                // No Comparative mapping for any other operation.
                return null;
        }

        Column partitionColumn;
        Object comparedValue;
        if (booleanFilter.getColumn() instanceof Column) {
            partitionColumn = OptimizerUtils.getColumn(column);
            comparedValue = getValue(value, partitionColumn.getType());
        } else {
            // Reversed form such as "1 = id": the column is on the right-hand side.
            partitionColumn = OptimizerUtils.getColumn(value);
            comparedValue = getValue(column, partitionColumn.getType());
            operationComp = Comparative.exchangeComparison(operationComp);
        }

        if (!columnName.equalsIgnoreCase(partitionColumn.getColumnName())) {
            return null;
        }
        if (isNull(comparedValue)) {
            // E.g. a bind value bound to null: handled like a literal null (no usable condition).
            return null;
        }
        // Validate the value that is actually handed to the rule engine (unwrapped/converted),
        // not the raw operand, which for "1 = id" is the column itself.
        if (!(comparedValue instanceof Comparable)) {
            throw new RuntimeException("type: " + comparedValue.getClass().getSimpleName() + " is not comparable, cannot be used in partition column");
        }
        return new Comparative(operationComp, (Comparable) comparedValue);
    }

    /**
     * Returns {@code true} for {@code null}, a {@link NullValue}, or a collection whose elements
     * are all null by this definition (including an empty collection).
     */
    private static boolean isNull(Object val) {
        if (val == null || val instanceof NullValue) {
            return true;
        }
        if (val instanceof Collection) {
            for (Object element : (Collection) val) {
                if (!isNull(element)) {
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    /**
     * Replaces a {@code NOW()} function with the current time as a {@link Date}. Any other input,
     * including other functions, is returned unchanged.
     */
    private static Object convertNowFunction(Object val) {
        if (val instanceof Function) {
            Function function = (Function) val;
            if ("NOW".equalsIgnoreCase(function.getFunctionName())) {
                return new Date();
            }
        }
        return val;
    }

    /**
     * Normalizes a value before it is compared with a partition column.
     * A {@link BindValue} is unwrapped, because bind values are left unreplaced in batch mode.
     * Otherwise, a value of a date-typed column is converted with {@code type.convert}.
     */
    private static Object getValue(Object value, Type type) {
        if (value instanceof BindValue) {
            return ((BindValue) value).getValue();
        }
        if (Types.isDateType(type)) {
            return type.convert(value);
        }
        return value;
    }

    /** Initializes the wrapped route if it is present and not initialized yet. */
    protected void doInit() {
        if (route != null && !route.isInited()) {
            route.init();
        }
    }

    /** Destroys the wrapped route if it is present and initialized. */
    protected void doDestroy() {
        if (route != null && route.isInited()) {
            route.destroy();
        }
    }

    /**
     * Computes the target data nodes for a logical table using a caller-supplied condition chooser.
     *
     * @param isWrite {@code true} for a write; the route is computed with {@code !isWrite} as its
     *                first argument
     * @return the non-empty list of {@link DataNodeDataScatterInfo}
     * @throws IllegalArgumentException if no target data node is found
     */
    public java.util.List route(String logicTable, ColumnNameToComparativeMapChoicer columnNameToComparativeMapChoicer, boolean isWrite) {
        RuleCalculateResult ruleCalculateResult = route.routes(!isWrite, logicTable, columnNameToComparativeMapChoicer, Lists.newArrayList());
        java.util.List<DataNodeDataScatterInfo> dataNodeDataScatterInfoList = ruleCalculateResult.getDataNodeDataScatterInfoList();
        if (dataNodeDataScatterInfoList == null || dataNodeDataScatterInfoList.isEmpty()) {
            throw new IllegalArgumentException("can't find target db. where is " + logicTable + ".");
        }
        return dataNodeDataScatterInfoList;
    }

    /**
     * Computes the target data nodes for a logical table from a filter condition. The condition
     * for each shard column is derived with {@link #getComparative(Filter, String)}.
     *
     * @param isWrite                 {@code true} for a write; passed to the route as {@code !isWrite}
     * @param forceAllowFullTableScan forwarded unchanged to the rule engine
     * @return the non-empty list of {@link DataNodeDataScatterInfo}
     * @throws IllegalArgumentException       if no target data node is found
     * @throws RuleCalculateDiffrentException if raised by the rule engine (propagated unchanged)
     */
    public java.util.List route(String logicTable, final Filter filter, boolean isWrite, boolean forceAllowFullTableScan) {
        ColumnNameToComparativeMapChoicer choicer = new ColumnNameToComparativeMapChoicer() {

            public Map<String, Comparative> getColumnNameToComparativeMap(Set<String> columnNameSet, java.util.List<Object> argumentList) {
                Map<String, Comparative> columnNameToComparativeMap = new HashMap<String, Comparative>();
                for (String columnName : columnNameSet) {
                    columnNameToComparativeMap.put(columnName, getColumnComparative(columnName, argumentList));
                }
                return columnNameToComparativeMap;
            }

            public Comparative getColumnComparative(String columnName, java.util.List argumentList) {
                return getComparative(filter, columnName);
            }
        };
        RuleCalculateResult ruleCalculateResult = route.routes(!isWrite, logicTable, choicer, Lists.newArrayList(), forceAllowFullTableScan);

        java.util.List<DataNodeDataScatterInfo> dataNodeDataScatterInfoList = ruleCalculateResult.getDataNodeDataScatterInfoList();
        if (dataNodeDataScatterInfoList == null || dataNodeDataScatterInfoList.isEmpty()) {
            throw new IllegalArgumentException("can't find target db. where is " + logicTable + ". filter is " + filter);
        }
        return dataNodeDataScatterInfoList;
    }

    /**
     * Returns one physical target for a logical table.
     * <p>
     * If the table has no rule, the result is the default data node with the logical name used
     * as-is (no name translation). Otherwise it is the first data node in the rule's actual
     * topology, in iteration order, that has a non-empty table set, together with the first table
     * of that set.
     *
     * @throws IllegalArgumentException if the rule's topology has no non-empty table set
     */
    public DataNodeDataScatterInfo shardAny(String logicTable) {
        TableRule tableRule = getTableRule(logicTable);
        if (tableRule == null) {
            DataNodeDataScatterInfo dataNodeDataScatterInfo = new DataNodeDataScatterInfo();
            dataNodeDataScatterInfo.setDataNodeId(getDefaultDataNodeId(logicTable));
            dataNodeDataScatterInfo.addTable(logicTable);
            return dataNodeDataScatterInfo;
        }

        for (Map.Entry<String, ? extends Set<String>> entry : tableRule.getActualTopology().entrySet()) {
            Set<String> tableNames = entry.getValue();
            if (tableNames == null || tableNames.isEmpty()) {
                continue;
            }
            DataNodeDataScatterInfo dataNodeDataScatterInfo = new DataNodeDataScatterInfo();
            dataNodeDataScatterInfo.setDataNodeId(entry.getKey());
            dataNodeDataScatterInfo.addTable(tableNames.iterator().next());
            return dataNodeDataScatterInfo;
        }
        throw new IllegalArgumentException("can't find any target db. where is " + logicTable + ". ");
    }

    /**
     * Returns whether the logical table is a single physical table in a single database, which is
     * the case exactly when no table rule exists for it.
     */
    public boolean isTableInSingleDb(String logicTable) {
        return getTableRule(logicTable) == null;
    }

    /**
     * Returns the default data node id the route reports for the logical table.
     *
     * @throws RuntimeException if the route has no default data node id
     */
    public String getDefaultDataNodeId(String logicTable) {
        String defaultDataNodeId = route.getDefaultDataNodeId(logicTable);
        if (defaultDataNodeId == null) {
            throw new RuntimeException("defaultDataNodeId is null");
        }
        return defaultDataNodeId;
    }

    /** Returns the rule's join data node id, or {@code null} if the table has no rule (single db). */
    public String getJoinDataNodeId(String logicTable) {
        TableRule tableRule = getTableRule(logicTable);
        return tableRule != null ? tableRule.getJoinDataNodeId() : null;
    }

    /** Returns whether the table is a broadcast table; a table without a rule never is. */
    public boolean isBroadCast(String logicTable) {
        TableRule table = getTableRule(logicTable);
        return table != null && table.isBroadcast();
    }

    /** Returns the rule's shard column names, or an empty list if the table has no rule. */
    public java.util.List getShardColumnNameList(String logicTable) {
        TableRule tableRule = getTableRule(logicTable);
        return tableRule != null ? tableRule.getShardColumnNameList() : new ArrayList<String>();
    }

    /** Returns the route's table rule for the logical table, or {@code null} if the route has none. */
    public TableRule getTableRule(String logicTable) {
        return route.getTableRule(logicTable);
    }

    /**
     * Merges the tables of the default database with the tables known from the routing rules.
     * <p>
     * The result contains:
     * <ul>
     * <li>every rule's logical table name;</li>
     * <li>every {@code dbIndexMap} key mapped to a data node other than the default one;</li>
     * <li>every table from {@code defaultDbTables} that is neither an actual table of some rule
     * nor a {@code dbIndexMap} key.</li>
     * </ul>
     * Secondary-index table names (containing {@code "._"}) are excluded from the first two groups.
     */
    public Set<String> mergeTableRule(java.util.List<String> defaultDbTables) {
        Set<String> result = new HashSet<String>();
        java.util.List<TableRule> tableRules = route.getTableRuleList();
        Map<String, String> dbIndexMap = route.getDbIndexMap();
        String defaultDb = route.getDefaultDataNodeId();

        // Sharded tables from the rules.
        for (TableRule tableRule : tableRules) {
            String table = tableRule.getVirtualTbName();
            if (!table.contains("._")) {
                result.add(table);
            }
        }
        // Tables pinned to a non-default data node; null-safe comparison with the default node.
        for (Map.Entry<String, String> entry : dbIndexMap.entrySet()) {
            if (!Objects.equals(entry.getValue(), defaultDb) && !entry.getKey().contains("._")) {
                result.add(entry.getKey());
            }
        }
        // Plain default-db tables: skip the physical tables that belong to a rule or the index map.
        for (String table : defaultDbTables) {
            if (!dbIndexMap.containsKey(table) && !isActualTableOfAnyRule(tableRules, table)) {
                result.add(table);
            }
        }
        return result;
    }

    /** Returns whether {@code table} is an actual (physical) table of any of the given rules. */
    private static boolean isActualTableOfAnyRule(java.util.List<TableRule> tableRules, String table) {
        for (TableRule tableRule : tableRules) {
            if (tableRule.isActualTable(table)) {
                return true;
            }
        }
        return false;
    }

}
